#include <stdlib.h>
#include <assert.h>
#include <math.h>

#include "io_psf.h"
#include "io_charmm_inp.h"
#include "esmd_types.h"
#include "cell.h"
#include "bonded.h"
#include "memory_cpp.hpp"
#define CMP_INT(a, b) (((a) > (b)) - ((a) < (b)))

static int cmp_temp_tori_param(const void *va, const void *vb) {
  const temp_dihed_param_t *a = (const temp_dihed_param_t *)va;
  const temp_dihed_param_t *b = (const temp_dihed_param_t *)vb;

  int a_wild = (a->itype == ELEMTAB_WILD && a->ltype == ELEMTAB_WILD);
  int b_wild = (b->itype == ELEMTAB_WILD && b->ltype == ELEMTAB_WILD);
  //put wildcard params at last
  return CMP_INT(a_wild, b_wild) * 81 + CMP_INT(a->jtype, b->jtype) * 27 + CMP_INT(a->ktype, b->ktype) * 9 + CMP_INT(a->itype, b->itype) * 3 + CMP_INT(a->ltype, b->ltype);
}

static int cmp_temp_impr_param(const void *va, const void *vb) {
  const temp_dihed_param_t *a = (const temp_dihed_param_t *)va;
  const temp_dihed_param_t *b = (const temp_dihed_param_t *)vb;
  int a_wild = impr_wildcard_level(a);
  int b_wild = impr_wildcard_level(b);
  //put wildcards at first
  return CMP_INT(b_wild, a_wild) * 81 + CMP_INT(a->itype, b->itype) * 27 + CMP_INT(a->jtype, b->jtype) * 9 + CMP_INT(a->ktype, b->ktype) * 3 + CMP_INT(a->ltype, b->ltype);
}

/*
 * Build the dense bond-parameter table for the types registered in elemtab.
 *
 * The table has elemtab->cnt * elemtab->cnt entries, indexed by
 * itype * ntypes + jtype, and starts zero-filled.
 *
 * Handling of raw entries:
 *   - An entry whose type names both resolve in elemtab is written under both
 *     (i, j) and (j, i), with shaked = 0.
 *   - An entry whose name resolves to ELEMTAB_NONE is skipped.
 *   - Later raw entries overwrite earlier ones.
 *
 * The returned array comes from esmd::allocate and is not freed here.
 */
bond_param_t *build_bond_params(charmm_rawparam_t *inp, elemtab_t *elemtab) {
  int ntypes = elemtab->cnt;
  bond_param_t *table = esmd::allocate(ntypes * ntypes, "params/bond", bond_param_t{0, 0, 0});

  for (int ib = 0; ib < inp->nbond; ib++) {
    const typeint ti = elemtab_get(elemtab, inp->bond[ib].itype);
    const typeint tj = elemtab_get(elemtab, inp->bond[ib].jtype);
    if (ti == ELEMTAB_NONE || tj == ELEMTAB_NONE)
      continue;

    // Forward slot first, then the mirrored one (identical when ti == tj).
    const int slots[2] = {ti * ntypes + tj, tj * ntypes + ti};
    for (int s = 0; s < 2; s++) {
      bond_param_t &entry = table[slots[s]];
      entry.kr = inp->bond[ib].kr;
      entry.r0 = inp->bond[ib].r0;
      entry.shaked = 0;
    }
  }
  return table;
}

/*
 * Build the dense angle-parameter table for the types registered in elemtab.
 *
 * The table has ntypes^3 entries, indexed by (jtype * ntypes + itype) * ntypes + ktype,
 * i.e. the central type j is the most significant index. It starts zero-filled.
 *
 * Handling of raw entries:
 *   - An entry whose three type names all resolve is written under both (i, j, k)
 *     and (k, j, i).
 *   - kub and s0 go into kub and r0ub; ktheta is copied as-is.
 *   - theta0 is converted from degrees to radians; shaked is set to 0.
 *   - Entries with an unresolved type (ELEMTAB_NONE) are skipped.
 *   - Later entries overwrite earlier ones.
 *
 * The returned array comes from esmd::allocate and is not freed here.
 */
angle_param_t *build_angle_params(charmm_rawparam_t *inp, elemtab_t *elemtab) {
  int ntypes = elemtab->cnt;
  angle_param_t *table = esmd::allocate(ntypes * ntypes * ntypes, "params/angle", angle_param_t{0, 0, 0, 0, 0});

  for (int ia = 0; ia < inp->nangle; ia++) {
    const typeint ti = elemtab_get(elemtab, inp->angle[ia].itype);
    const typeint tj = elemtab_get(elemtab, inp->angle[ia].jtype);
    const typeint tk = elemtab_get(elemtab, inp->angle[ia].ktype);
    if (ti == ELEMTAB_NONE || tj == ELEMTAB_NONE || tk == ELEMTAB_NONE)
      continue;

    // Forward (i, j, k) slot first, then the mirrored (k, j, i) one.
    const int slots[2] = {
      (tj * ntypes + ti) * ntypes + tk,
      (tj * ntypes + tk) * ntypes + ti,
    };
    for (int s = 0; s < 2; s++) {
      angle_param_t &entry = table[slots[s]];
      entry.kub = inp->angle[ia].kub;
      entry.r0ub = inp->angle[ia].s0;
      entry.ktheta = inp->angle[ia].ktheta;
      entry.theta0 = inp->angle[ia].theta0 * M_PI / 180.;
      entry.shaked = 0;
    }
  }
  return table;
}

/*
 * Build the torsion (proper dihedral) parameter table into *tab.
 *
 * tab->param
 *   Every usable torsion term, in cmp_temp_tori_param order. phi0 is converted
 *   from degrees to radians.
 *
 * tab->range
 *   One {start, count} per (i, j, k, l) type quadruple, indexed by htori_type().
 *   It names the run tab->param[start .. start + count) that applies to that
 *   quadruple; count == 0 means no parameter.
 *
 * Handling of raw entries:
 *   - Entries with a type name unknown to elemtab are dropped.
 *   - Every entry that is not its own reverse is also stored as (l, k, j, i).
 *   - An entry with a wildcard end fills every (iwc, j, k, lwc) quadruple that
 *     has no explicit (non-wildcard) parameter. Explicit parameters always take
 *     precedence.
 *
 * The "X" wildcard is registered with elemtab only for the duration of this call.
 */
void build_tori_params_tab(tori_param_tab_t *tab, charmm_rawparam_t *inp, elemtab_t *elemtab) {
  int ntypes = elemtab->cnt;
  // Presumably makes elemtab_get resolve "X" to ELEMTAB_WILD; undone at the end.
  elemtab_set_wild(elemtab, "X");
  int tot_dihed_param = ntypes * ntypes * ntypes * ntypes;

  // Room for every raw entry plus its reversed copy.
  temp_dihed_param_t *tmp_tori = esmd::allocate(inp->ntori * 2, "params/torison temp", temp_dihed_param_t{0, 0, 0, 0, 0, 0, 0});
  int ntmp_tori = 0;
  for (int i = 0; i < inp->ntori; i++) {
    typeint itype = elemtab_get(elemtab, inp->tori[i].itype);
    typeint jtype = elemtab_get(elemtab, inp->tori[i].jtype);
    typeint ktype = elemtab_get(elemtab, inp->tori[i].ktype);
    typeint ltype = elemtab_get(elemtab, inp->tori[i].ltype);

    if (itype != ELEMTAB_NONE && jtype != ELEMTAB_NONE && ktype != ELEMTAB_NONE && ltype != ELEMTAB_NONE) {
      tmp_tori[ntmp_tori].itype = itype;
      tmp_tori[ntmp_tori].jtype = jtype;
      tmp_tori[ntmp_tori].ktype = ktype;
      tmp_tori[ntmp_tori].ltype = ltype;
      tmp_tori[ntmp_tori].phi0 = inp->tori[i].phi0 * M_PI / 180.;
      tmp_tori[ntmp_tori].kphi = inp->tori[i].kphi;
      tmp_tori[ntmp_tori].n = inp->tori[i].n;
      ntmp_tori++;
    }
  }

  // Append the reversed (l, k, j, i) copy of every entry that is not its own reverse.
  int ntmp_tori_di = ntmp_tori;
  for (int i = 0; i < ntmp_tori; i++) {
    if (tmp_tori[i].itype != tmp_tori[i].ltype || tmp_tori[i].jtype != tmp_tori[i].ktype) {
      tmp_tori[ntmp_tori_di].itype = tmp_tori[i].ltype;
      tmp_tori[ntmp_tori_di].jtype = tmp_tori[i].ktype;
      tmp_tori[ntmp_tori_di].ktype = tmp_tori[i].jtype;
      tmp_tori[ntmp_tori_di].ltype = tmp_tori[i].itype;
      tmp_tori[ntmp_tori_di].phi0 = tmp_tori[i].phi0;
      tmp_tori[ntmp_tori_di].kphi = tmp_tori[i].kphi;
      tmp_tori[ntmp_tori_di].n = tmp_tori[i].n;
      ntmp_tori_di++;
    }
  }
  ntmp_tori = ntmp_tori_di;

  // Sort order: explicit entries first, grouped by (j, k, i, l); entries with
  // wildcards at both ends come last. Terms for one quadruple end up contiguous.
  qsort(tmp_tori, ntmp_tori, sizeof(temp_dihed_param_t), cmp_temp_tori_param);

  tori_param_range_t *range = esmd::allocate(tot_dihed_param, "params/torsion/range", tori_param_range_t{0, 0});
  for (int i = 0; i < tot_dihed_param; i++)
    range[i].count = 0;

  // Exclusive end of the explicit (non-wildcard) prefix of tmp_tori.
  // Any range whose start is >= nowild_end was opened by a wildcard entry.
  // An exclusive end, rather than the last explicit index, keeps the test
  // correct when there are no explicit entries at all: nowild_end == 0 then,
  // and a wildcard range starting at index 0 is still counted.
  //
  // NOTE(review): this loop treats an entry as wildcard if EITHER end is "X",
  // whereas cmp_temp_tori_param moves only entries with BOTH ends "X" to the back.
  // The two agree for the X-j-k-X form used by standard CHARMM torsion wildcards.
  int nowild_end = 0;
  for (int i = 0; i < ntmp_tori; i++) {
    const temp_dihed_param_t &t = tmp_tori[i];
    if (t.itype != ELEMTAB_WILD && t.ltype != ELEMTAB_WILD) {
      int itor = htori_type(ntypes, t.itype, t.jtype, t.ktype, t.ltype);
      if (range[itor].count == 0)
        range[itor].start = i;
      range[itor].count++;
      nowild_end = i + 1;
    } else {
      for (int iwc = 0; iwc < ntypes; iwc++)
        for (int lwc = 0; lwc < ntypes; lwc++) {
          int itor = htori_type(ntypes, iwc, t.jtype, t.ktype, lwc);
          if (range[itor].count == 0)
            range[itor].start = i;
          // Only grow ranges opened by a wildcard entry. A quadruple that already
          // has an explicit parameter keeps it.
          if (range[itor].start >= nowild_end)
            range[itor].count++;
        }
    }
  }

  tori_param_t *tori_params = esmd::allocate(ntmp_tori, "params/torsion/param", tori_param_t{0, 0});
  for (int i = 0; i < ntmp_tori; i++) {
    tori_params[i].phi0 = tmp_tori[i].phi0;
    tori_params[i].kphi = tmp_tori[i].kphi;
    tori_params[i].n = tmp_tori[i].n;
  }
  esmd::deallocate(tmp_tori);
  tab->param = tori_params;
  tab->range = range;
  elemtab_unset_wild(elemtab, "X");
}

/*
 * Build the dense harmonic-improper parameter table.
 *
 * The table has ntypes^4 entries, indexed by himpr_type(). It is zero-filled,
 * so quadruples matched by no parameter read as all-zero.
 *
 * Handling of raw entries:
 *   - Entries with a type name unknown to elemtab are dropped.
 *   - Every entry that is not its own reverse is also stored as (l, k, j, i).
 *   - A wildcard position expands to every type.
 *   - Entries are applied in cmp_temp_impr_param order (higher wildcard level
 *     first), so a later, more specific entry overwrites a more generic one.
 *   - phi0 is converted from degrees to radians.
 *
 * The "X" wildcard is registered with elemtab only for the duration of this call.
 * The returned array comes from esmd::allocate and is not freed here.
 */
harmonic_impr_param_t *build_impr_params(charmm_rawparam_t *inp, elemtab_t *elemtab) {
  int ntypes = elemtab->cnt;
  int tot_dihed_param = ntypes * ntypes * ntypes * ntypes;
  // Presumably makes elemtab_get resolve "X" to ELEMTAB_WILD; undone at the end.
  elemtab_set_wild(elemtab, "X");

  // Room for every raw entry plus its reversed copy.
  temp_dihed_param_t *tmp_impr = esmd::allocate<temp_dihed_param_t>(inp->nimpr * 2, "params/improper temp");
  int ntmp_impr = 0;
  for (int i = 0; i < inp->nimpr; i++) {
    typeint itype = elemtab_get(elemtab, inp->impr[i].itype);
    typeint jtype = elemtab_get(elemtab, inp->impr[i].jtype);
    typeint ktype = elemtab_get(elemtab, inp->impr[i].ktype);
    typeint ltype = elemtab_get(elemtab, inp->impr[i].ltype);

    if (itype != ELEMTAB_NONE && jtype != ELEMTAB_NONE && ktype != ELEMTAB_NONE && ltype != ELEMTAB_NONE) {
      tmp_impr[ntmp_impr].itype = itype;
      tmp_impr[ntmp_impr].jtype = jtype;
      tmp_impr[ntmp_impr].ktype = ktype;
      tmp_impr[ntmp_impr].ltype = ltype;
      tmp_impr[ntmp_impr].phi0 = inp->impr[i].phi0 * M_PI / 180.;
      tmp_impr[ntmp_impr].kphi = inp->impr[i].kphi;
      tmp_impr[ntmp_impr].n = inp->impr[i].n;
      ntmp_impr++;
    }
  }

  // Append the reversed (l, k, j, i) copy of every entry that is not its own reverse.
  int ntmp_impr_di = ntmp_impr;
  for (int i = 0; i < ntmp_impr; i++) {
    if (tmp_impr[i].itype != tmp_impr[i].ltype || tmp_impr[i].jtype != tmp_impr[i].ktype) {
      tmp_impr[ntmp_impr_di].itype = tmp_impr[i].ltype;
      tmp_impr[ntmp_impr_di].jtype = tmp_impr[i].ktype;
      tmp_impr[ntmp_impr_di].ktype = tmp_impr[i].jtype;
      tmp_impr[ntmp_impr_di].ltype = tmp_impr[i].itype;
      tmp_impr[ntmp_impr_di].phi0 = tmp_impr[i].phi0;
      tmp_impr[ntmp_impr_di].kphi = tmp_impr[i].kphi;
      tmp_impr[ntmp_impr_di].n = tmp_impr[i].n;
      ntmp_impr_di++;
    }
  }
  ntmp_impr = ntmp_impr_di;

  // More-wildcarded entries first, so more specific ones overwrite them below.
  qsort(tmp_impr, ntmp_impr, sizeof(temp_dihed_param_t), cmp_temp_impr_param);

  // Pass an explicit zero initial value, as the other builders in this file do.
  // Quadruples with no matching parameter are then well-defined zeros instead of
  // depending on what the no-initializer esmd::allocate overload leaves in memory.
  harmonic_impr_param_t *impr_params = esmd::allocate(tot_dihed_param, "params/improper", harmonic_impr_param_t{});
  for (int i = 0; i < ntmp_impr; i++) {
    const temp_dihed_param_t &t = tmp_impr[i];
    // Each position covers [type, type + 1), or every type when it is a wildcard.
    // Bounds are computed in int so that "type + 1" cannot wrap a narrow typeint.
    int imin = t.itype, imax = t.itype + 1;
    int jmin = t.jtype, jmax = t.jtype + 1;
    int kmin = t.ktype, kmax = t.ktype + 1;
    int lmin = t.ltype, lmax = t.ltype + 1;
    if (t.itype == ELEMTAB_WILD) {
      imin = 0;
      imax = ntypes;
    }
    if (t.jtype == ELEMTAB_WILD) {
      jmin = 0;
      jmax = ntypes;
    }
    if (t.ktype == ELEMTAB_WILD) {
      kmin = 0;
      kmax = ntypes;
    }
    if (t.ltype == ELEMTAB_WILD) {
      lmin = 0;
      lmax = ntypes;
    }
    for (int iwc = imin; iwc < imax; iwc++) {
      for (int jwc = jmin; jwc < jmax; jwc++) {
        for (int kwc = kmin; kwc < kmax; kwc++) {
          for (int lwc = lmin; lwc < lmax; lwc++) {
            int iimp = himpr_type(ntypes, iwc, jwc, kwc, lwc);
            impr_params[iimp].phi0 = t.phi0;
            impr_params[iimp].kphi = t.kphi;
          }
        }
      }
    }
  }
  esmd::deallocate(tmp_impr);
  elemtab_unset_wild(elemtab, "X");
  return impr_params;
}

/*
 * Fill *param with the nonbonded settings and per-type van der Waals parameters.
 *
 * Scalar settings:
 *   - rsw, rcut, coul_const and scale14 are copied from the arguments.
 *   - nonb_threshold is fixed at 10000 (NOTE(review): meaning not visible here).
 *
 * param->vdw_param has one entry per elemtab type and is zero-filled first.
 * For each raw nonbonded entry whose type resolves in elemtab:
 *   - eps   = sqrt(-emin),   eps14 = sqrt(-emin14);
 *   - sig   = rmin,          sig14 = rmin14 (copied unchanged).
 *
 * The sqrt implies emin and emin14 are expected to be <= 0; a positive value
 * yields NaN. Types with no raw entry keep all-zero parameters.
 */
void build_nonbonded_param(nonbonded_param_t *param, real rcut, real rsw, real coul_const, real scale14, charmm_rawparam_t *inp, elemtab_t *elemtab) {
  param->rsw = rsw;
  param->rcut = rcut;
  param->nonb_threshold = 10000;
  param->coul_const = coul_const;
  param->scale14 = scale14;
  param->ntypes = elemtab->cnt;
  int ntypes = elemtab->cnt;
  // Explicit zero initial value, matching the other builders. A type without a
  // nonbonded entry then gets well-defined zero parameters.
  param->vdw_param = esmd::allocate(ntypes, "params/nonbonded", vdw_param_t{});
  for (int i = 0; i < inp->nnonb; i ++) {
    int itype = elemtab_get(elemtab, inp->nonb[i].itype);
    if (itype != ELEMTAB_NONE) {
      param->vdw_param[itype].eps = sqrt(-inp->nonb[i].emin);
      param->vdw_param[itype].sig = inp->nonb[i].rmin;
      param->vdw_param[itype].eps14 = sqrt(-inp->nonb[i].emin14);
      param->vdw_param[itype].sig14 = inp->nonb[i].rmin14;
    }
  }
}

/*
 * Assemble the complete parameter set in *param for the types in elemtab.
 *
 * The builders run in a fixed order: nonbonded, bond, angle, torsion, improper.
 * The torsion and improper builders temporarily register the "X" wildcard in
 * elemtab, so this order is kept exactly.
 *
 * param->ntypes is read from elemtab only after every builder has returned.
 */
void build_param(charmm_param_t *param, charmm_rawparam_t *inp, elemtab_t *elemtab, real rcut, real rsw, real coul_const, real scale14){
  charmm_param_t &out = *param;

  build_nonbonded_param(&out.nonb, rcut, rsw, coul_const, scale14, inp, elemtab);
  out.bond  = build_bond_params(inp, elemtab);
  out.angle = build_angle_params(inp, elemtab);
  build_tori_params_tab(&out.tori, inp, elemtab);
  out.impr  = build_impr_params(inp, elemtab);

  out.ntypes = elemtab->cnt;
}